I'm going to begin this lab by showing you a YouTube video that I found extremely helpful in understanding the EM algorithm.
Some important things to note:
# load packages
# Load packages (pacman::p_load installs any that are missing, then loads them)
library(pacman)
p_load(ggplot2,
       gganimate,  # animated ggplots (transition_states below)
       tidyverse,
       knitr,
       kableExtra,
       readr,
       janitor,
       lme4,       # for fitting linear mixed models
       lattice)    # for plotting random effects

Now, let's consider an EM algorithm for a simple case where we have a mixture of two Poisson distributions (so we don't have to estimate so many parameters). That is, suppose we know our data come from two Poisson distributions, but we don't know the means of these distributions, nor do we know the corresponding proportions/probabilities that our data are drawn from each of the distributions.
Let's simulate some data. Choose \(\lambda_1 = 3\) and \(\lambda_2 = 8\) with mixing probabilities \(\frac{1}{6}\) and \(\frac{5}{6}\). Then our data will look something like this.
# Simulate 1000 draws from the mixture 1/6 * Pois(3) + 5/6 * Pois(8)
set.seed(123)
# number of observations from each component (binomial split of 1000)
n1 <- rbinom(n = 1, size = 1000, prob = 1 / 6)
n2 <- 1000 - n1
set.seed(123)
# draw n1 observations from Pois(3) and n2 from Pois(8), then concatenate
sim_mixture <- c(rpois(n = n1, lambda = 3), rpois(n = n2, lambda = 8))
# Histogram of each component sample and of the pooled mixture
data.frame(x = sim_mixture,
           mixture = as.character(c(rep("Sample 1", n1), rep("Sample 2", n2)))) %>%
  bind_rows(data.frame(x = sim_mixture, mixture = "Mixture")) %>%
  ggplot(aes(x = x, fill = mixture)) +
  geom_histogram(binwidth = 1) +
  theme_bw() +
  facet_wrap(. ~ mixture, nrow = 2) +
  scale_fill_viridis_d()

Now, given our data we can also write down our likelihood with respect to our parameter vector \(\pmb\theta = (\theta_1, \theta_2, \lambda_1, \lambda_2).\)
Now we begin the steps of the EM algorithm by first defining our \(Q\) function.
Notice that the only parts that depend on \(u\) are the indicators. Everything else will be unchanged by taking the expectation, since \(\pmb{y}\) and \(\pmb\theta\) are taken to be given in this step.
So we can plug these back into our \(Q\) function and since nothing else depends on \(u\) now (has been integrated out) we can move to maximizing \(Q\) over our unknown variables.
Now we take derivative with respect to our various parameters to maximize \(Q\).
Maximizing with respect to \(\theta_1\) and \(\theta_2\):
Maximizing with respect to \(\lambda_1\) and \(\lambda_2\):
So these are our estimates for \(\theta_1,\theta_2,\lambda_1, \text{ and }\lambda_2\) at each iteration.
Now, let's implement this using our data. First we must set initial values of our unknown parameters. Let's begin with \[\pmb\theta^{(0)} = \left(\tfrac{1}{2},\tfrac{1}{2},\bar{y}-1,\bar{y}+1\right),\] i.e. equal mixing probabilities, with the two Poisson means initialized one unit below and one unit above the sample mean.
At each iteration, we will keep track of:
What is the observed data likelihood at any given iteration?
# Set run_EM = TRUE to re-run the EM loop; FALSE loads cached results below.
run_EM = FALSE
if(run_EM){
  # Parameter history: one row per EM iteration (theta1, theta2, lambda1, lambda2).
  # Initialization: equal mixing weights, component means at round(ybar) -/+ 1.
  theta = data.frame(th1 = .5, th2 = .5,
                     l1 = round(mean(sim_mixture) - 1),
                     l2 = round(mean(sim_mixture) + 1))
  # Q(theta^{t+1} | theta^t) at each iteration
  Q_fun = c()
  # observed-data log likelihood at each iteration
  ll = c()
  # posterior membership probabilities P(u_i = 1 | y_i, theta^t) per iteration
  u = data.frame(iteration = NULL, sample_point = NULL, posterior_prob = NULL)
  # relative change in theta between iterations; 1e-20 is below double
  # precision, so in effect we iterate until theta stops changing at all
  delta = 10
  while(abs(delta) > 1e-20){
    i = nrow(theta)
    y = sim_mixture
    # current estimates theta^t
    th1 = theta$th1[i]
    th2 = theta$th2[i]
    l1 = theta$l1[i]
    l2 = theta$l2[i]
    # E step: P(u_i = 1 | y, theta). The factorial(x) terms of the Poisson
    # densities cancel in the ratio, so they are omitted.
    ui1 = function(x){
      exp(-l1)*l1^x*th1/(exp(-l1)*l1^x*th1+exp(-l2)*l2^x*th2)
    }
    # P(u_i = 2 | y, theta)
    ui2 = function(x){
      exp(-l2)*l2^x*th2/(exp(-l1)*l1^x*th1+exp(-l2)*l2^x*th2)
    }
    # responsibilities for every observation, computed once per iteration
    # (ui1/ui2 are vectorized, so ui1(y) equals sapply(y, ui1))
    p1 = ui1(y)
    p2 = ui2(y)
    # record P(u_i = 1 | y, theta^t) at each distinct sample value
    u = bind_rows(u, data.frame(iteration = i,
                                sample_point = unique(y),
                                posterior_prob = ui1(unique(y))))
    # M step: closed-form maximizers of Q derived above
    th1_new = mean(p1)
    th2_new = mean(p2)
    l1_new = mean(y*p1)/th1_new
    l2_new = mean(y*p2)/th2_new
    # component densities at the updated lambda estimates
    f1_new = function(x){
      exp(-l1_new)*l1_new^x/factorial(x)
    }
    f2_new = function(x){
      exp(-l2_new)*l2_new^x/factorial(x)
    }
    # Q at the new theta (the expectation is still taken under theta^t)
    Q_new = sum(p1*(log(th1_new)+log(f1_new(y)))+
                  p2*(log(th2_new)+log(f2_new(y))))
    Q_fun = c(Q_fun, Q_new)
    # append theta^{t+1} to the history
    theta = bind_rows(theta, data.frame(th1 = th1_new, th2 = th2_new,
                                        l1 = l1_new, l2 = l2_new))
    # relative change |theta^{t+1}-theta^{t}| / |theta^{t}|
    delta = norm(matrix(as.numeric(theta[i,])-as.numeric(theta[(i+1),])))/norm(matrix(as.numeric(theta[i,])))
    # observed-data log likelihood at theta^{t+1}
    obs_ll = sum(log(th1_new*dpois(y,l1_new)+th2_new*dpois(y, l2_new)))
    ll = c(ll, obs_ll)
  }
  # cache results so the loop need not be rerun when knitting
  saveRDS(theta, "theta.rds")
  saveRDS(Q_fun, "Q_fun.rds")
  saveRDS(ll, "ll.rds")
  saveRDS(u, "u.rds")
}

Here I show the values of our parameters (\(\theta_1\), \(\theta_2\), \(\lambda_1\), and \(\lambda_2\)), the Q function, and observed data log likelihood at each iteration of the algorithm implemented above.
# If the EM loop was skipped, restore the previously cached results from disk.
if (!run_EM) {
  theta <- readRDS("theta.rds")
  Q_fun <- readRDS("Q_fun.rds")
  ll <- readRDS("ll.rds")
  u <- readRDS("u.rds")
}
# Trace of the mixing-weight estimates across iterations
theta %>%
  mutate(iteration = 1:nrow(theta)) %>%
  select(th1, th2, iteration) %>%
  pivot_longer(1:2, names_to = "theta") %>%
  ggplot(aes(x = iteration, y = value, color = theta)) +
  geom_line() +
  theme_bw()

# Trace of the Poisson-mean estimates across iterations
theta %>%
  mutate(iteration = 1:nrow(theta)) %>%
  select(l1, l2, iteration) %>%
  pivot_longer(1:2, names_to = "lambda") %>%
  ggplot(aes(x = iteration, y = value, color = lambda)) +
  geom_line() +
  theme_bw()

# Q function value at each iteration
data.frame(Q = Q_fun) %>%
  mutate(iteration = seq_along(Q_fun)) %>%
  ggplot(aes(x = iteration, y = Q)) +
  geom_point(size = .5) +
  labs(title = "Q function convergence") +
  theme_bw()

# Observed-data log likelihood at each iteration (monotone under EM)
data.frame(obs_loglik = ll, iteration = seq_along(ll)) %>%
  ggplot(aes(x = iteration, y = obs_loglik)) +
  geom_point(size = .5) +
  theme_bw() +
  labs(title = "Observed Data Log Likelihood Convergence", y = "Observed Log Likelihood")

We can also try to look at what is happening when we maximize the Q function at each iteration. It is a function of 4 (or 3 if you take \(\theta_2=1-\theta_1\)) variables, so I plot the Q function with respect to each parameter to give a sense of how the algorithm is working (note we are constrained to \(\theta_2=1-\theta_1\)).
y = sim_mixture
# Rebuilding the Q-surface animation is slow; set sims = TRUE to regenerate it.
sims = FALSE
if(sims){
  # Q evaluated at arbitrary (th1, th2, l1, l2): the E-step weights ui1/ui2
  # are recomputed at the supplied parameters, exactly as in the four
  # originally duplicated Q_* functions this helper replaces.
  Q_given = function(th1, th2, l1, l2){
    ui1 = function(x){exp(-l1)*l1^x*th1/(exp(-l1)*l1^x*th1+exp(-l2)*l2^x*th2)}
    ui2 = function(x){exp(-l2)*l2^x*th2/(exp(-l1)*l1^x*th1+exp(-l2)*l2^x*th2)}
    f1 = function(x){ exp(-l1)*l1^x/factorial(x)}
    f2 = function(x){ exp(-l2)*l2^x/factorial(x)}
    sum(ui1(y)*(log(th1)+log(f1(y)))+
          ui2(y)*(log(th2)+log(f2(y))))
  }
  # One-dimensional profiles of Q: vary one coordinate, hold the other three
  # at the current (global) estimates set inside the loop below.
  Q_theta1 = function(th1) Q_given(th1, th2, l1, l2)
  Q_theta2 = function(th2) Q_given(th1, th2, l1, l2)
  Q_lambda1 = function(l1) Q_given(th1, th2, l1, l2)
  Q_lambda2 = function(l2) Q_given(th1, th2, l1, l2)
  # evaluation grids for the mixing weights and the Poisson means
  th_grid = seq(0, 1, by = .01)
  lam_grid = seq(0, 15, by = .1)
  Qs = data.frame(iteration = NULL, parameter = NULL, x = NULL, Q = NULL)
  for(i in 1:60){
    params = theta[i,]
    th1 = params$th1
    th2 = params$th2
    l1 = params$l1
    l2 = params$l2
    # profile Q along each coordinate (sapply over the grid replaces the
    # original grow-with-c() loops; the values are identical)
    qth1 = sapply(th_grid, Q_theta1)
    qth2 = sapply(th_grid, Q_theta2)
    ql1 = sapply(lam_grid, Q_lambda1)
    ql2 = sapply(lam_grid, Q_lambda2)
    Qs = bind_rows(Qs, data.frame(iteration = i,
                                  parameter = c(rep("theta1",length(qth1)),
                                                rep("theta2",length(qth2)),
                                                rep("lambda1",length(ql1)),
                                                rep("lambda2",length(ql2))),
                                  param_value = c(rep(th1,length(qth1)),
                                                  rep(th2,length(qth2)),
                                                  rep(l1,length(ql1)),
                                                  rep(l2,length(ql2))),
                                  next_value = c(rep(theta[i+1,]$th1,length(qth1)),
                                                 rep(theta[i+1,]$th2,length(qth2)),
                                                 rep(theta[i+1,]$l1,length(ql1)),
                                                 rep(theta[i+1,]$l2,length(ql2))),
                                  x = c(th_grid, th_grid, lam_grid, lam_grid),
                                  Q = c(qth1,qth2,ql1,ql2)))
  }
  # grid point closest to the current estimate, used to place the red marker
  Qs.max = Qs%>%group_by(iteration, parameter)%>%slice(which.min(abs(x-param_value)))
  anim_Q = Qs%>%
    ggplot(aes(x = x, y = Q))+
    geom_line()+
    geom_point(data = Qs.max, aes(x = next_value, y = Q), color = "red")+
    transition_states(iteration)+
    facet_wrap(.~parameter, scales = "free_x",nrow = 2)+
    theme_bw()+
    # {closest_state} is the transition_states label (the iteration); the
    # original "{frame}" showed the raw frame counter, which differs from
    # the iteration when gganimate interpolates multiple frames per state
    labs(subtitle = "Iteration = {closest_state}", title = "Q Function Maximization")
  saveRDS(anim_Q, "anim_Q.rds")
}
readRDS("anim_Q.rds")

Finally, here is a visualization of the two Poisson distributions (and respective proportions) that the algorithm estimates at each iteration.
# limit to first 50 iterations (doesn't change much after)
if(sims){
  sim_data = data.frame(iteration = NULL, value = NULL)
  for(i in 1:50){
    params = theta[i,]
    # weighted component densities evaluated on the support 0:20
    pdf1 = dpois(x=0:20, lambda=params$l1)*params$th1
    pdf2 = dpois(x=0:20, lambda=params$l2)*params$th2
    pdf = pdf1 + pdf2
    # NOTE: x must be 0:20 to match the support the densities were evaluated
    # on; the original used x = 1:21, which shifted every fitted curve right
    # by one count. Regenerate sim_data.rds after this fix.
    sim_data = bind_rows(sim_data, data.frame(iteration = i, x = 0:20,
                                              combined = pdf, f1 = pdf1, f2 = pdf2,
                                              l1 = params$l1, l2 = params$l2))
  }
  saveRDS(sim_data, "sim_data.rds")
}
if(!sims){
  sim_data = readRDS("sim_data.rds")}
sim_data%>%
  # empirical distribution of the simulated data, rescaled to probabilities
  # so it is on the same scale as the fitted densities
  full_join(data.frame(table(sim_mixture))%>%
              rename(x = sim_mixture, "Simulated Data"= Freq)%>%
              mutate(x = as.numeric(as.character(x)))%>%
              mutate(`Simulated Data` = `Simulated Data`/sum(`Simulated Data`)),
            by = "x")%>%
  filter(iteration<=50)%>%
  # attach the posterior membership probability of each sample point
  full_join(u%>%rename(x = sample_point)%>%filter(iteration<=50),
            by = c("iteration","x"))%>%
  na.omit()%>%
  ggplot(aes(x = x))+
  geom_histogram(aes(x = x, y=`Simulated Data`),binwidth = 1,
                 fill = "grey", alpha = .6, stat = "identity")+
  # dashed lines mark the current estimates of lambda1 and lambda2
  geom_vline(aes(xintercept = l1), color ="#238A8DFF", linetype = "dashed" )+
  geom_vline(aes(xintercept = l2), color = "#FDE725FF", linetype = "dashed")+
  geom_density(aes(y = combined ), color = "#440154FF",stat = "identity")+
  geom_density(aes(y = f1), color = "#238A8DFF",stat = "identity")+
  geom_density(aes(y = f2), color = "#FDE725FF",stat = "identity")+
  scale_fill_viridis_c(begin = 0.4, end = 1, direction = -1)+
  theme_bw()+
  transition_states(iteration)+
  # {closest_state} shows the iteration; "{frame}" would show the raw frame
  # counter, which differs once gganimate interpolates between states
  labs(y = "Density (scaled)", title = "Iteration = {closest_state}",
       fill = "Posterior Probability")+
  geom_point(shape = 21, aes(x = x, y = -.005, fill = posterior_prob),
             color = "#440154FF", size = 3, stroke = .3)